import scipy.io as sio
import torch
import csv
import matplotlib.pyplot as plt
from tqdm import tqdm
import imshift
from class_Angular_Spectrum import Angular_Spectrum
import os
import time


# Device selection: expose only GPU 0 to this process (set before any CUDA
# query below), then fall back to the CPU when CUDA is not available.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Wall-clock reference used later to report data-loading / total run time.
start_time = time.time()

# Reconstruction grid size, in pixels.
N1 = N2 = 512

# Ptychography / propagation parameters.
pxsize = 50e-3    # pixel pitch (mm); scan shifts are divided by it to get pixels
wavlen = 116e-3   # wavelength (mm)
dist = 20         # propagation distance (mm)
radius = 130 / 2  # radius in pixels (currently unused in this script)


# Object translation positions: a K1 x K2 raster centred on the origin with
# spacing `step` (same length unit as pxsize, i.e. mm).
K1, K2 = 8, 8
K = K1 * K2
step = 1.2
half_span_1 = step * (K1 - 1) / 2
half_span_2 = step * (K2 - 1) / 2
shifts_1_range = torch.linspace(-half_span_1, half_span_1, K1)
shifts_2_range = torch.linspace(-half_span_2, half_span_2, K2)
# 'xy' indexing produces exactly the transpose of the 'ij' grids, i.e.
# shifts_1 varies along columns and shifts_2 along rows.
shifts_1, shifts_2 = torch.meshgrid(shifts_1_range, shifts_2_range, indexing='xy')

# Estimation error for the probe positions (mm): uniform noise in
# [-pos_error/2, pos_error/2) added to every nominal position.
pos_error = 0.005
perturbs_1, perturbs_2 = (
    pos_error * torch.rand_like(grid) - pos_error / 2 for grid in (shifts_1, shifts_2)
)

# Flatten to (K, 1) so that shifts_*[k] is a one-element tensor per position.
shifts_1 = (shifts_1 + perturbs_1).reshape(K, 1).to(device)
shifts_2 = (shifts_2 + perturbs_2).reshape(K, 1).to(device)

# Read the measured intensity stack. The ePIE loop indexes it as y[k, :, :]
# per scan position k (presumably shape (K, N1, N2) - confirm against the file).
mat_contents = sio.loadmat('image3_8x8.mat')
y = torch.from_numpy(mat_contents['tensor']).to(device).double()
# Keep the data as complex128 with a zero imaginary part.
y = torch.complex(y, torch.zeros_like(y))


# Probe initialisation: real-valued initial guess loaded from disk, promoted
# to complex128 (zero imaginary part) on the compute device.
probe_est = sio.loadmat('probe_est_initial.mat')
probe_est = probe_est['probe_est']
# Force float64 so torch.complex below always yields complex128, even when the
# .mat file was saved in single precision.
probe_est = torch.from_numpy(probe_est).to(device).double()
# Follow the probe's own shape/dtype/device instead of hard-coding 512x512,
# so changing N1/N2 (or the probe file) does not break the construction.
imag_part = torch.zeros_like(probe_est)
probe_est = torch.complex(probe_est, imag_part)

# Object initialisation: uniform unit amplitude, zero phase, allocated
# directly on the compute device.
obj_est = torch.ones((N1, N2), dtype=torch.complex128, device=device)

# Number of ePIE cycles (full passes over all K scan positions).
n_iters = 5000

# Angular-spectrum propagators over +dist (ASM1) and -dist (ASM2).
# NOTE(review): both are built from N1 only, which presumably assumes a square
# N1 x N1 grid - confirm against Angular_Spectrum.
ASM1 = Angular_Spectrum(wavlen, N1, pxsize, dist)
ASM2 = Angular_Spectrum(wavlen, N1, pxsize, -dist)

# Output folders: '1cycle/' receives a .mat file for each of the first 100
# cycles, '100cycle/' receives the periodic figures and .mat snapshots.
dir_save = './results/ePIE/'
dir_save1 = f'{dir_save}1cycle/'
dir_save2 = f'{dir_save}100cycle/'
for _out_dir in (dir_save, dir_save1, dir_save2):
    os.makedirs(_out_dir, exist_ok=True)

# Column names of the per-cycle timing CSV.
header = ['no', 'cycle time', 'total time', 'read time']


# ePIE reconstruction.
#
# Each cycle visits all K scan positions in random order. For position k the
# object is shifted by (shifts_1[k], shifts_2[k]) / pxsize pixels and
# multiplied by the probe. The product is propagated with ASM1, its modulus is
# replaced by sqrt(y[k]), and it is propagated back with ASM2. The resulting
# exit-wave change updates both the object and the probe.
#
# Outputs:
#   - timings are appended to a CSV every cycle;
#   - a .mat file is written for each of the first 100 cycles;
#   - a figure and a .mat snapshot are written at cycles 1, 10, 50 and every
#     multiple of 100.
data_read_time = time.time() - start_time
alpha_1 = 1              # object update step size
alpha_2 = 1              # probe update step size
probe_update_start = 1   # first cycle at which the probe is updated (1 = from the start)
file_time = 0            # cumulative cycle time, logged as 'total time'

# Measured modulus, computed once rather than per position per cycle.
# Negative intensities (e.g. after background subtraction) are clamped to
# zero; the complex sqrt of a negative value would otherwise give an
# imaginary "amplitude" and inject a spurious 90-degree phase.
y_amp = torch.sqrt(torch.clamp(y.real, min=0))

path_csv = dir_save + 'ePIE' + '_sample' + '.csv'

# Central crop used for the zoomed panels of the progress figure.
# NOTE(review): 129:384 is 255 px and slightly off-centre on a 512 grid;
# 128:384 may have been intended - confirm.
_crop = (slice(129, 384), slice(129, 384))


def _save_results_mat(filename, obj, probe):
    """Save amplitude and phase of the object and probe estimates to a .mat file."""
    sio.savemat(filename, {'output_obj_A': torch.abs(obj).cpu().numpy(),
                           'output_obj_P': torch.angle(obj).cpu().numpy(),
                           'output_pro_A': torch.abs(probe).cpu().numpy(),
                           'output_pro_P': torch.angle(probe).cpu().numpy()})


def _append_csv_row(path, row, header_row):
    """Append `row` to the CSV at `path`, writing `header_row` first if the file is new."""
    write_header = not os.path.isfile(path)
    with open(path, 'a', newline='') as log_file:
        writer = csv.writer(log_file)
        if write_header:
            writer.writerow(header_row)
        writer.writerow(row)


def _plot_panel(position, image, cmap, title=None):
    """Draw a real-valued tensor into panel `position` of a 2x4 subplot grid."""
    plt.subplot(2, 4, position)
    plt.imshow(image.cpu().numpy(), cmap=cmap)
    if title is not None:
        plt.title(title)
    plt.colorbar(fraction=0.05, pad=0.05)
    plt.axis('off')


def _save_progress_figure(filename, cycle, obj, probe):
    """Plot object/probe amplitude and phase (cropped and full), then save and show.

    A fresh figure is created and closed on every call. Reusing the implicit
    current figure would stack a new colorbar onto each panel at every
    snapshot (and leak the figure) on non-interactive backends.
    """
    cycle_label = 'cycle' + (str(cycle).zfill(6))
    fig = plt.figure()
    _plot_panel(1, torch.abs(obj[_crop]), 'gray', 'lable (amp)')
    _plot_panel(2, torch.angle(obj[_crop]), 'inferno', 'lable (pha)')
    _plot_panel(3, torch.abs(obj), 'gray', 'lable (amp)')
    _plot_panel(4, torch.angle(obj), 'inferno', 'lable (pha)')
    _plot_panel(5, torch.abs(probe[_crop]), 'gray', cycle_label)
    _plot_panel(6, torch.angle(probe[_crop]), 'inferno')
    _plot_panel(7, torch.abs(probe), 'gray', cycle_label)
    _plot_panel(8, torch.angle(probe), 'inferno')
    plt.savefig(filename)
    plt.show()
    plt.close(fig)


for cycle in tqdm(range(1, n_iters + 1)):

    iter_start = time.time()
    for k in torch.randperm(K):
        # Shift for scan position k, converted from mm to pixels.
        s1 = shifts_1[k] / pxsize
        s2 = shifts_2[k] / pxsize

        # Object region seen by the probe at position k. It is reused in the
        # probe update so both updates share the same forward model.
        obj_patch = imshift.imshift(obj_est, s1, s2, device)
        exit_wave = probe_est * obj_patch
        u_est = ASM1(exit_wave)
        # Modulus constraint: keep the propagated phase, impose measured amplitude.
        u_est = y_amp[k] * torch.exp(1j * torch.angle(u_est))
        exit_wave_new = ASM2(u_est)
        delta = exit_wave_new - exit_wave

        # Object update: shift the probe and the exit-wave change back into the
        # object frame (-s1, -s2). Uses the probe estimate from before this step.
        obj_est_new = (obj_est
                       + alpha_1 * torch.conj(imshift.imshift(probe_est, -s1, -s2, device))
                       / (torch.max(torch.abs(probe_est) ** 2) + 0.02)
                       * imshift.imshift(delta, -s1, -s2, device))

        if cycle >= probe_update_start:
            # Probe update: conj of the SAME object patch (+s1, +s2) that formed
            # exit_wave, using the object estimate from before this step. The
            # previous code used a mixed (-s1, +s2) shift here, which does not
            # match the forward model.
            probe_est = (probe_est
                         + alpha_2 * torch.conj(obj_patch)
                         / (torch.max(torch.abs(obj_est) ** 2) + 0.02)
                         * delta)

        obj_est = obj_est_new

    if cycle <= 100:
        _save_results_mat(dir_save1 + 'ePIE' + '_cycle' + (str(cycle).zfill(6)) + '_results.mat',
                          obj_est, probe_est)

    iter_time = time.time() - iter_start
    file_time += iter_time
    _append_csv_row(path_csv, [cycle, iter_time, file_time, data_read_time], header)

    if cycle in (1, 10, 50) or cycle % 100 == 0:
        stem = dir_save2 + 'ePIE' + '_cycle' + (str(cycle).zfill(6)) + '_results'
        _save_progress_figure(stem + '.png', cycle, obj_est, probe_est)
        _save_results_mat(stem + '.mat', obj_est, probe_est)








